import pandas as pd
import numpy as np
import jieba
from sklearn.model_selection import train_test_split


def data_loader(path):
    """Load an Excel sheet that has no header row into a DataFrame.

    Args:
        path: Path (or anything ``pd.read_excel`` accepts) of the workbook.

    Returns:
        pandas.DataFrame whose columns are labelled 0, 1, ... because
        ``header=None``; the row index is the default RangeIndex.
    """
    # The original passed ``index=None``, which is not a read_excel parameter
    # (current pandas raises TypeError for it); ``index_col=None`` is the
    # intended option and keeps the default RangeIndex.
    return pd.read_excel(path, header=None, index_col=None)

def cw(x):
    """Segment one document into a list of jieba tokens.

    Args:
        x: One cell of the input sheet. Normally a str; missing cells
            (NaN / None) are treated as empty text and any other scalar is
            converted with ``str``.

    Returns:
        list[str]: tokens produced by ``jieba.cut`` (default mode) with
        whitespace-only tokens removed.
    """
    # Excel cells may be NaN (empty) or numeric; jieba.cut expects text.
    if not isinstance(x, str):
        x = "" if pd.isna(x) else str(x)
    # jieba.cut yields whitespace runs ("\n", "\t", " ") as tokens of their
    # own. Downstream files store one document per line with tab-separated
    # tokens, so such tokens would split a document across lines (shifting
    # labels in the batch files) or create empty fields; drop them here.
    return [token for token in jieba.cut(x) if token.strip()]

# Batch-generation settings: number of examples per training batch file and
# the filename prefix onto which the batch index (or "test") is appended.
BATCH_SIZE = 64
BATCH_PATH = "../data/batch_data/20190626_"

# Source spreadsheets for the positive and negative classes.
POS_PATH = "../data/pos.xls"
NEG_PATH = "../data/neg.xls"

# Load both classes (positive first, then negative) with data_loader.
pos, neg = [data_loader(sheet_path) for sheet_path in (POS_PATH, NEG_PATH)]

# Tokenize every document. Column 0 holds the text because the sheets were
# loaded with header=None; cw returns one token list per cell, stored in a
# new "words" column.
pos['words'] = pos[0].apply(cw)
neg['words'] = neg[0].apply(cw)

# Pull the token lists out of the DataFrames as plain Python lists.
pos_word_list = pos['words'].tolist()
neg_word_list = neg['words'].tolist()

# Save each tokenized corpus: one document per line, tokens tab-separated.
POS_CW_PATH = "../data/pos_cw.txt"
NEG_CW_PATH = "../data/neg_cw.txt"
for cw_path, word_lists in ((POS_CW_PATH, pos_word_list), (NEG_CW_PATH, neg_word_list)):
    with open(cw_path, "w", encoding="utf8") as f:
        # Use write() per line: writelines() on a single str iterates it and
        # issues one write per character.
        for words in word_lists:
            f.write("\t".join(words) + "\n")

# Merge both classes into one corpus with aligned labels (1 = pos, 0 = neg).
all_data = pos_word_list + neg_word_list
all_labels = [1] * len(pos_word_list) + [0] * len(neg_word_list)

# Hold out 20% as the test set; random_state keeps the split reproducible.
train_X, test_X, train_y, test_y = train_test_split(all_data, all_labels, test_size=0.2, random_state=10)

# Shuffle the training set once more, with a seeded generator so the batch
# order is reproducible like the split. Reorder with plain list indexing:
# np.array() over ragged token lists raises ValueError on NumPy >= 1.24.
index = list(range(len(train_X)))
np.random.RandomState(10).shuffle(index)
train_X = [train_X[i] for i in index]
train_y = [train_y[i] for i in index]

# Write the training set as batch files BATCH_PATH + "0", "1", ..., each
# holding up to BATCH_SIZE lines of "<label>\t<token>\t<token>...". Slicing
# per batch (instead of reopening a file after every BATCH_SIZE-th row)
# avoids leaving an empty trailing batch file when len(train_X) is an exact
# multiple of BATCH_SIZE, and `with` guarantees every file is closed.
for batch_idx, start in enumerate(range(0, len(train_X), BATCH_SIZE)):
    stop = start + BATCH_SIZE
    with open(BATCH_PATH + str(batch_idx), "w", encoding="utf8") as fout:
        for label, words in zip(train_y[start:stop], train_X[start:stop]):
            fout.write("%s\t%s\n" % (label, "\t".join(words)))

# Write the held-out test set to BATCH_PATH + "test" in the same
# "<label>\t<token>..." line format as the training batches. The shuffle is
# seeded for reproducibility; reordering uses list indexing because
# np.array() over ragged token lists raises ValueError on NumPy >= 1.24, and
# the context manager ensures the file is flushed and closed.
index = list(range(len(test_X)))
np.random.RandomState(10).shuffle(index)
test_X = [test_X[i] for i in index]
test_y = [test_y[i] for i in index]
with open(BATCH_PATH + "test", "w", encoding="utf8") as fout:
    for label, words in zip(test_y, test_X):
        fout.write("%s\t%s\n" % (label, "\t".join(words)))